{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "start_on_colab.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyPIpbiKdUC+5q3/Vlm8dk4L"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "PwHOLr9y7S3Q"
      },
      "source": [
        "from IPython.display import HTML, display\n",
        "\n",
        "def set_css():\n",
        "  display(HTML('''\n",
        "  <style>\n",
        "    pre {\n",
        "        white-space: pre-wrap;\n",
        "    }\n",
        "  </style>\n",
        "  '''))\n",
        "get_ipython().events.register('pre_run_cell', set_css)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zXkQlZUJ8Lnm"
      },
      "source": [
        "!git clone --recursive https://github.com/chenyaofo/image-classification-codebase\n",
        "%cd image-classification-codebase\n",
        "\n",
        "%pip install -qr requirements.txt\n",
        "\n",
        "import torch\n",
        "from IPython.display import clear_output\n",
        "\n",
        "clear_output()\n",
        "print('Setup complete. Using torch %s %s' % (torch.__version__, torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ijy9oBmN6ovz"
      },
      "source": [
        "The following command list all the available models."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 69
        },
        "id": "y-AG3Ne76dOv",
        "outputId": "5607e9cb-31aa-48d6-a465-65662eebb653"
      },
      "source": [
        "import torch\n",
        "print(torch.hub.list(\"chenyaofo/pytorch-cifar-models\", force_reload=True))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "  <style>\n",
              "    pre {\n",
              "        white-space: pre-wrap;\n",
              "    }\n",
              "  </style>\n",
              "  "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Downloading: \"https://github.com/chenyaofo/pytorch-cifar-models/archive/master.zip\" to /root/.cache/torch/hub/master.zip\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "['cifar100_resnet20', 'cifar100_resnet32', 'cifar100_resnet44', 'cifar100_resnet56', 'cifar100_vgg11_bn', 'cifar100_vgg13_bn', 'cifar100_vgg16_bn', 'cifar100_vgg19_bn', 'cifar10_resnet20', 'cifar10_resnet32', 'cifar10_resnet44', 'cifar10_resnet56', 'cifar10_vgg11_bn', 'cifar10_vgg13_bn', 'cifar10_vgg16_bn', 'cifar10_vgg19_bn']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4jofqHNwAgvS"
      },
      "source": [
        "To train/evaluate different models, please set `model.model_name` to available model names listed in the previous block."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YA5-PFL8BPUq"
      },
      "source": [
        "**Train on CIFAR-10:**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BM_9IlZ72D9l"
      },
      "source": [
        "!python -m entry.run --conf conf/cifar10.conf -o output/cifar10/resnet20 -M model.name=cifar10_resnet20"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QwrIVjWQNV82"
      },
      "source": [
        "**Evaluate on CIFAR-10:**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DMd_USWH81Ov"
      },
      "source": [
        "!python -m entry.run --conf conf/cifar10.conf -o output/cifar10/resnet20 -M model.name=cifar10_resnet20 model.pretrained=true only_evaluate=true"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aKYW0yNINoSk"
      },
      "source": [
        "**Train on CIFAR-100:**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m2bVfrEjNoHY"
      },
      "source": [
        "!python -m entry.run --conf conf/cifar100.conf -o output/cifar100/resnet20 -M model.name=cifar100_resnet20"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PgcF8E8NNxgp"
      },
      "source": [
        "**Evaluate on CIFAR-100:**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5k-sCfhqNxDc"
      },
      "source": [
        "!python -m entry.run --conf conf/cifar100.conf -o output/cifar100/resnet20 -M model.name=cifar100_resnet20 model.pretrained=true only_evaluate=true"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}